{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load IPython's autoreload extension and switch it to mode 2.\n",
    "# NOTE(review): presumably so edits to locally developed modules are picked up without a kernel restart - confirm.\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "# Restrict this process to the GPU with index 2; the index is specific to the machine this was written on.\n",
    "# This cell runs before the TensorFlow import. NOTE(review): the variable presumably has to be set before\n",
    "# TensorFlow initialises its devices in order to take effect - confirm.\n",
    "os.environ['CUDA_VISIBLE_DEVICES'] = '2'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "\n",
    "from path_explain.path_explainer_tf import PathExplainerTF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two examples with two features each; input for the sum-of-squares gradient example below.\n",
    "# Note that x is rebound further down to a tensor with three features per example.\n",
    "x = tf.constant([[2.0, 5.0], [6.0, 8.0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=17, shape=(2, 2), dtype=float32, numpy=\n",
       "array([[ 4., 10.],\n",
       "       [12., 16.]], dtype=float32)>"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Record z[i] = sum_j x[i, j]**2 on a tape and differentiate it with respect to x.\n",
    "with tf.GradientTape() as t:\n",
    "    # x is a constant rather than a Variable, so it is registered with the tape explicitly.\n",
    "    t.watch(x)\n",
    "    y = tf.math.square(x)\n",
    "    z = tf.math.reduce_sum(y, axis=-1)\n",
    "# The saved output equals 2 * x, the analytic derivative of the sum of squares.\n",
    "t.gradient(target=z, sources=x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One dense unit over two input features, no activation.\n",
    "model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Dense(units=1, input_shape=(2,)),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[<tf.Variable 'dense/kernel:0' shape=(2, 1) dtype=float32, numpy=\n",
       " array([[-0.5324452 ],\n",
       "        [-0.42080152]], dtype=float32)>,\n",
       " <tf.Variable 'dense/bias:0' shape=(1,) dtype=float32, numpy=array([0.], dtype=float32)>]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the model's kernel and bias; no seed is set in this notebook, so the kernel values differ between runs.\n",
    "model.weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second test input: the first row is 100 times the first row of the 2x2 x, the second row is the same as in x.\n",
    "k = tf.constant([[200.0, 500.0], [6.0, 8.0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=63, shape=(2, 2), dtype=float32, numpy=\n",
       "array([[-0.5324452 , -0.42080152],\n",
       "       [-0.5324452 , -0.42080152]], dtype=float32)>"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Differentiate the model output with respect to the input k.\n",
    "with tf.GradientTape() as t:\n",
    "    # k is a constant, so it is registered with the tape explicitly.\n",
    "    t.watch(k)\n",
    "    y = model(k)\n",
    "# In the saved output both rows equal the kernel values shown by the model.weights cell.\n",
    "t.gradient(target=y, sources=k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebinds x: two examples with three features each, matching the input_shape=(3,) model defined next.\n",
    "x = tf.constant([[2.0, 5.0, 3.0], [6.0, 8.0, 1.0]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rebinds model: a 10-unit softmax layer over three input features followed by a single linear output unit.\n",
    "model = tf.keras.Sequential([\n",
    "    tf.keras.layers.Dense(units=10, input_shape=(3,), activation='softmax'),\n",
    "    tf.keras.layers.Dense(units=1),\n",
    "])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 98,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Second-order derivatives of the model output with respect to the input x, using two nested tapes.\n",
    "with tf.GradientTape() as second_order_tape:\n",
    "    second_order_tape.watch(x)\n",
    "    with tf.GradientTape() as first_order_tape:\n",
    "        first_order_tape.watch(x)\n",
    "        z = model(x)\n",
    "    # Taken while the outer tape is still recording, so the gradient computation itself can be differentiated.\n",
    "    grad = first_order_tape.gradient(z, x)\n",
    "# Derivative of every entry of grad with respect to every entry of x.\n",
    "jacobian = second_order_tape.jacobian(grad, x)\n",
    "# Disabled alternative. NOTE(review): indexed_grad is not defined in any cell of this notebook.\n",
    "# second_grad = second_order_tape.gradient(indexed_grad, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 104,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[[-0.02110989,  0.01286735,  0.01510115],\n",
       "        [ 0.01286735, -0.0107625 , -0.00567085],\n",
       "        [ 0.01510115, -0.00567085, -0.01363467]],\n",
       "\n",
       "       [[ 0.02816404, -0.01690738, -0.01374249],\n",
       "        [-0.01690738,  0.01197595,  0.00869031],\n",
       "        [-0.01374249,  0.00869031,  0.0086968 ]]], dtype=float32)"
      ]
     },
     "execution_count": 104,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Keep only the blocks of the jacobian where the differentiated example and the input example coincide,\n",
    "# giving one matrix of second derivatives per example.\n",
    "# The number of examples is read from the jacobian instead of being hard-coded to 2.\n",
    "example_idx = np.arange(int(jacobian.shape[0]))\n",
    "# Using the same index array on axes 0 and 2 pairs them element-wise; NumPy puts that paired axis first.\n",
    "jacobian.numpy()[example_idx, :, example_idx, :]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=2047, shape=(2, 3), dtype=float32, numpy=\n",
       "array([[ 0.01286735, -0.0107625 , -0.00567085],\n",
       "       [-0.01690738,  0.01197595,  0.00869031]], dtype=float32)>"
      ]
     },
     "execution_count": 71,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# second_grad was not assigned by any cell, so this cell raised NameError on a fresh kernel.\n",
    "# It is recomputed here as the derivative of one component of the first-order gradient with respect to x.\n",
    "with tf.GradientTape() as second_order_tape:\n",
    "    second_order_tape.watch(x)\n",
    "    with tf.GradientTape() as first_order_tape:\n",
    "        first_order_tape.watch(x)\n",
    "        z = model(x)\n",
    "    grad = first_order_tape.gradient(z, x)\n",
    "    # Sliced while the outer tape is recording so the indexing is part of the differentiated computation.\n",
    "    # Feature index 1 is inferred from the saved output, whose rows match row 1 of each per-example matrix\n",
    "    # shown by the jacobian cell above.\n",
    "    indexed_grad = grad[:, 1]\n",
    "second_grad = second_order_tape.gradient(indexed_grad, x)\n",
    "second_grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 50,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=1093, shape=(2, 3), dtype=float32, numpy=\n",
       "array([[ 4., 10.,  6.],\n",
       "       [12., 16.,  2.]], dtype=float32)>"
      ]
     },
     "execution_count": 50,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the first-order gradient.\n",
    "# NOTE(review): the saved output equals 2 * x for the 2x3 x, so it appears to come from a run in which z was a\n",
    "# sum of squares rather than the model output - re-run to refresh.\n",
    "grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=0, shape=(2, 2), dtype=float32, numpy=\n",
       "array([[2., 5.],\n",
       "       [6., 8.]], dtype=float32)>"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the current input tensor.\n",
    "# NOTE(review): the saved output is the 2x2 tensor, while the latest assignment above makes x a 2x3 tensor;\n",
    "# the output looks stale.\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=360, shape=(2,), dtype=float32, numpy=array([ 29., 100.], dtype=float32)>"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show z as last assigned.\n",
    "# NOTE(review): the saved values are the row-wise sums of squares of the 2x2 x (4 + 25 and 36 + 64), not a\n",
    "# model output; the output looks stale.\n",
    "z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=374, shape=(2, 2), dtype=float32, numpy=\n",
       "array([[ 4., 10.],\n",
       "       [12., 16.]], dtype=float32)>"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the first-order gradient (the same expression is already displayed by an earlier cell).\n",
    "# NOTE(review): the saved output equals 2 * x for the 2x2 x; it looks stale.\n",
    "grad"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# No cell defined `tape`, so this cell raised NameError on a fresh kernel. A second derivative also needs the\n",
    "# first-order gradient to be computed while the outer tape is recording, so both tapes are set up here.\n",
    "with tf.GradientTape() as tape:\n",
    "    tape.watch(x)\n",
    "    with tf.GradientTape() as inner_tape:\n",
    "        inner_tape.watch(x)\n",
    "        z = model(x)\n",
    "    grad = inner_tape.gradient(z, x)\n",
    "# Differentiates grad with respect to x; for a non-scalar target the tape differentiates the sum of its entries.\n",
    "tape.gradient(grad, x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute attributions for the first example of x against an all-zero baseline.\n",
    "# NOTE(review): _single_attribution is a private PathExplainerTF method; the argument meanings here are read\n",
    "# off the keyword names only - confirm against the path_explain source.\n",
    "num_samples = 10\n",
    "explainer = PathExplainerTF(model)\n",
    "# The baseline width follows x instead of being fixed at 2, so it still matches after x is rebound to three\n",
    "# features per example.\n",
    "baseline = np.zeros((1, int(x.shape[-1])), dtype=np.float32)\n",
    "eg_single = explainer._single_attribution(x[0], baseline,\n",
    "                            current_alphas=np.linspace(0.0, 1.0, num_samples).astype(np.float32),\n",
    "                            num_samples=num_samples, batch_size=50,\n",
    "                            use_expectation=False, output_index=0, input_indices=None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.5297718, 5.702946 ], dtype=float32)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the attribution vector assigned to eg_single by the previous cell.\n",
    "eg_single"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=117, shape=(2,), dtype=float32, numpy=array([0.5297718, 5.7029467], dtype=float32)>"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Element-wise product of the first column of the model's first weight tensor with the first example of x.\n",
    "# The saved values agree with the saved eg_single output up to float32 rounding.\n",
    "model.weights[0][:, 0] * x[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=125, shape=(1, 1), dtype=float32, numpy=array([[6.2327185]], dtype=float32)>"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Model output for the first example; slicing with 0:1 keeps the leading batch axis.\n",
    "# The saved value equals the sum of the two saved attribution values shown above.\n",
    "model(x[0:1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.5297718, 5.702946 ], dtype=float32)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show eg_single again for comparison with the two values above (same expression as an earlier cell).\n",
    "eg_single"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 105,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<tf.Tensor: id=1076, shape=(2, 3), dtype=float32, numpy=\n",
       "array([[2., 5., 3.],\n",
       "       [6., 8., 1.]], dtype=float32)>"
      ]
     },
     "execution_count": 105,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the current input tensor; the saved output is the 2x3 x assigned above.\n",
    "x"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
